import numpy as np


def get_noise_schedule(config):
    """Build the exploration-noise schedule selected by ``config``.

    Reads ``config["gfn"]["noise_exploration"]["noise_profile"]`` to choose the
    schedule shape and ``config["n_iterations"]`` for its length. For every
    profile except ``"none"``, it also reads ``initial_noise`` and ``final_noise``
    from the same sub-dict.

    Args:
        config: Nested mapping holding the keys listed above.

    Returns:
        np.ndarray of float noise values. ``"none"`` yields exactly
        ``n_iterations + 1`` zeros. The other profiles are called with
        ``n_iterations + 1`` as their step count. Each returns at least
        ``n_iterations + 1`` values; the exact length depends on the profile.

    Raises:
        ValueError: If ``noise_profile`` is not a supported profile name.
    """
    noise_cfg = config["gfn"]["noise_exploration"]
    noise_profile = noise_cfg["noise_profile"]
    valid_profiles = (
        "linear",
        "quadratic_convex",
        "quadratic_concave",
        "cosine_annealing",
        "exponential",
        "exponential_flat",
        "none",
    )
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would make an invalid profile silently return None.
    if noise_profile not in valid_profiles:
        raise ValueError(
            f"Invalid noise profile {noise_profile!r}; expected one of {valid_profiles}"
        )

    n_steps = config["n_iterations"] + 1
    if noise_profile == "none":
        # Float zeros so every profile returns the same dtype family.
        return np.full(n_steps, 0.0)

    schedule_builders = {
        "linear": linear_schedule,
        "quadratic_convex": quadratic_convex_schedule,
        "quadratic_concave": quadratic_concave_schedule,
        "cosine_annealing": cosine_annealing_schedule,
        "exponential": exponential_schedule,
        "exponential_flat": exponential_flat_schedule,
    }
    # NOTE(review): schedules receive n_iterations + 1 steps and therefore
    # return n_iterations + 2 or more values; confirm the caller's expected length.
    return schedule_builders[noise_profile](
        noise_cfg["initial_noise"], noise_cfg["final_noise"], n_steps
    )

def linear_schedule(initial, final, iterations):
    """Interpolate linearly from ``initial`` to ``final``.

    Returns ``iterations + 1`` evenly spaced values; both endpoints are included.
    """
    num_points = iterations + 1
    return np.linspace(start=initial, stop=final, num=num_points)

def quadratic_concave_schedule(initial, final, iterations):
    """Quadratic ramp anchored at ``initial`` with zero slope at the start.

    Step ``k`` (0..iterations) takes the value
    ``initial + (final - initial) * (k / iterations) ** 2``, which reaches
    ``final`` at ``k == iterations``. When ``final < initial`` the curve is
    concave: it stays near ``initial`` early and drops fastest at the end.
    Returns ``iterations + 1`` values.
    """
    steps = np.arange(iterations + 1)
    curvature = (final - initial) / iterations ** 2
    return curvature * steps ** 2 + initial

def quadratic_convex_schedule(initial, final, iterations):
    """Quadratic ramp anchored at ``final`` with zero slope at the end.

    With ``r = iterations - k`` steps remaining, step ``k`` takes the value
    ``final - (final - initial) * (r / iterations) ** 2``. This starts at
    ``initial`` and flattens onto ``final``. When ``final < initial`` the curve
    is convex: it drops fastest at the start. Returns ``iterations + 1`` values.
    """
    steps_left = np.flip(np.arange(iterations + 1))
    curvature = (final - initial) / iterations ** 2
    return final - curvature * steps_left ** 2

def single_cosine_annealing_schedule(initial, final, iterations):
    """One half-cosine sweep from ``initial`` down to ``final``.

    Step ``k`` (0..iterations) equals
    ``final + (initial - final) * 0.5 * (1 + cos(pi * k / iterations))``.
    Returns ``iterations + 1`` values; ``iterations == 0`` divides by zero.
    """
    phase = np.pi * np.arange(iterations + 1) / iterations
    half_gap = (initial - final) * 0.5
    return final + half_gap * (1 + np.cos(phase))

def cosine_annealing_schedule(initial, final, iterations):
    """Cosine annealing with warm restarts of halving height.

    The budget is split into 6 slices of ``iterations // 6`` steps each. Slice
    ``i`` runs ``single_cosine_annealing_schedule`` from the peak
    ``final + (initial - final) / 2**i`` down to ``final``. The first slice
    therefore starts exactly at ``initial``, like the other schedules, and the
    last slice is held flat at ``final``.

    Returns:
        The concatenated slices: ``6 * (slice_length + 1)`` values, because each
        slice includes both of its endpoints.
    """
    number_of_slices = 6
    # Clamp to at least one step: a zero-length slice makes the helper divide
    # 0 by 0 and emit NaNs when iterations < number_of_slices.
    slice_length = max(iterations // number_of_slices, 1)
    # Halve the gap above `final` at each restart. Using (initial - final) keeps
    # the first peak at `initial` instead of overshooting it by `final`.
    peaks = [final + (initial - final) / (2 ** i) for i in range(number_of_slices)]
    peaks[-1] = final
    return np.concatenate(
        [single_cosine_annealing_schedule(peak, final, slice_length) for peak in peaks]
    )

def exponential_schedule(initial, final, iterations):
    """Exponential decay from roughly ``initial`` towards ``final``.

    Uses the per-step factor ``gamma = exp(-2e / iterations)``. Step ``k``
    (0..iterations) equals
    ``(initial - final) * (gamma**k - gamma**(iterations + 1)) + final``.
    Because ``gamma**(iterations + 1)`` is subtracted, the first value is
    slightly below ``initial`` and the last is slightly above ``final``.
    Returns ``iterations + 1`` values.
    """
    decay = np.exp(-2 * np.exp(1) / iterations)
    powers = decay ** np.arange(iterations + 1)
    offset = decay ** (iterations + 1)
    amplitude = initial - final
    return amplitude * (powers - offset) + final

def exponential_flat_schedule(initial, final, iterations):
    """Exponential decay over the first half, then a plateau at ``final``.

    The decay part uses ``gamma = exp(-2e / (iterations / 2))`` and covers steps
    ``0..iterations // 2``. It has the same form as ``exponential_schedule``:
    ``(initial - final) * (gamma**k - gamma**(half + 1)) + final``. It is
    followed by ``iterations // 2 + 1`` copies of ``final``. Note that gamma
    uses true division while the step count uses floor division, so for odd
    ``iterations`` the two halves of the formula use slightly different lengths.

    Returns:
        ``2 * (iterations // 2 + 1)`` values.
    """
    half = iterations // 2
    decay = np.exp(-2 * np.exp(1) / (iterations / 2))
    decaying = (initial - final) * (decay ** np.arange(half + 1) - decay ** (half + 1)) + final
    plateau = np.full(half + 1, final)
    return np.concatenate([decaying, plateau])